import sys
sys.path.append('..')

from src.parsing.Parse import parse_file
from src.parsing.Ast import Maximize, Minimize, AssertSoft
from src.parsing.Types import BITVECTOR_TYPE, INTEGER_TYPE, REAL_TYPE

class opt:
    """A parsed SMT-LIB script plus its declared variables.

    Provides helpers that append optimization commands (maximize, minimize,
    assert-soft) to the end of the script's command list.
    """

    def __init__(self, file):
        # parse_file returns the script object and the variable->type mapping.
        parsed_script, declared_vars = parse_file(file)
        self.script = parsed_script
        self.variables = declared_vars

    def _append(self, command):
        """Append *command* to the end of the script's command list."""
        self.script.commands.append(command)

    def add_maximize(self, var):
        """Append a maximize objective for *var*."""
        self._append(Maximize(var))

    def add_minimize(self, var):
        """Append a minimize objective for *var*."""
        self._append(Minimize(var))

    def add_soft(self, var, weight):
        """Append a soft assertion on *var* with the given *weight*."""
        self._append(AssertSoft(var, weight))


def test_optimize(smt_file):
    """Rewrite *smt_file* in place so that every numeric variable is optimized.

    Each declared variable whose type is ``INTEGER_TYPE`` or ``REAL_TYPE``,
    or is an instance of ``BITVECTOR_TYPE``, gets one maximize and one
    minimize objective appended to the parsed script. The file is then
    overwritten with the script, preceded by
    ``(set-option :produce-models true)`` and followed by ``(check-sat)``
    and ``(get-objectives)``.

    Despite the ``test_`` prefix, this takes a file path argument and
    mutates that file; it is a helper, not a self-contained test.

    :param smt_file: path of the SMT-LIB file to read and overwrite.
    """
    o = opt(smt_file)
    for name, var_type in o.variables.items():
        # Int/Real are matched by value, while bit-vector types are matched
        # by class (presumably one instance per width -- verify in Types).
        # A single merged condition guarantees each variable is added once.
        if var_type in (INTEGER_TYPE, REAL_TYPE) or isinstance(var_type, BITVECTOR_TYPE):
            o.add_maximize(name)
            o.add_minimize(name)
    with open(smt_file, 'w') as f:
        f.write("(set-option :produce-models true)\n" + str(o.script) + '\n(check-sat)\n(get-objectives)\n')

def _split_result(text):
    """Split solver output into (status_lines, objective_lines).

    status_lines are the stripped, non-blank lines before the first
    "(objectives" marker. objective_lines are the stripped, non-blank lines
    after it, or None when the output contains no "(objectives" block.
    """
    head, marker, tail = text.partition("(objectives")
    status_lines = [line.strip() for line in head.splitlines() if line.strip()]
    if not marker:
        return status_lines, None
    objective_lines = [line.strip() for line in tail.splitlines() if line.strip()]
    return status_lines, objective_lines


def _parse_objectives(lines):
    """Map each objective key to the list of its values, in output order.

    A line such as "(n 1)" splits on its first space into key "(n" and value
    "1)". Lines without a space (e.g. the closing ")") are ignored. Values
    are kept in a list because maximizing and minimizing the same term
    reports the same key more than once.
    """
    objectives = {}
    for line in lines:
        if " " not in line:
            continue
        key, value = line.split(" ", 1)
        objectives.setdefault(key, []).append(value)
    return objectives


def compare_res(res_file1, res_file2):
    """
    Compare two solver result files and classify any discrepancy.

    The format of the result file is:
    sat

    (objectives
    (n 1)
    )

    Returns "soundness" if, at some aligned position, one file reports
    "sat" and the other "unsat". Otherwise, when both files contain an
    objectives block, returns "optimization" if any objective present in
    both files has a differing value. Objectives reported by only one file
    are skipped, and repeated keys (e.g. maximize and minimize of the same
    term) are compared in the order they appear. Returns None when no
    discrepancy is found.
    """
    with open(res_file1, 'r') as f:
        status1, objectives1 = _split_result(f.read())
    with open(res_file2, 'r') as f:
        status2, objectives2 = _split_result(f.read())

    # Compare status lines pairwise up to the shorter output; only a
    # definite sat/unsat disagreement counts as a soundness issue.
    verdicts = ('sat', 'unsat')
    for line1, line2 in zip(status1, status2):
        if line1 != line2 and line1 in verdicts and line2 in verdicts:
            return "soundness"

    if objectives1 is None or objectives2 is None:
        return None

    values1 = _parse_objectives(objectives1)
    values2 = _parse_objectives(objectives2)
    for key, vals1 in values1.items():
        vals2 = values2.get(key)
        if vals2 is None:
            continue
        if any(v1 != v2 for v1, v2 in zip(vals1, vals2)):
            return "optimization"

    return None
    
                


if __name__ == '__main__':
    # Compare two solver result files and print the discrepancy kind
    # ("soundness", "optimization" or None). Paths may be passed as the
    # first and second command-line arguments; otherwise the previously
    # hard-coded sample files are used.
    res_path1 = sys.argv[1] if len(sys.argv) > 1 else '../test/test_opt1.txt'
    res_path2 = sys.argv[2] if len(sys.argv) > 2 else '../test/test_opt2.txt'
    print(compare_res(res_path1, res_path2))

    
